{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use SARSA($\\lambda$) to Play Taxi-v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Log INFO and above to stdout with HH:MM:SS timestamps; logging.debug calls\n",
    "# made elsewhere in the notebook stay silent at this level.\n",
    "logging.basicConfig(level=logging.INFO,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] env: <TaxiEnv<Taxi-v3>>\n",
      "00:00:00 [INFO] action_space: Discrete(6)\n",
      "00:00:00 [INFO] observation_space: Discrete(500)\n",
      "00:00:00 [INFO] reward_range: (-inf, inf)\n",
      "00:00:00 [INFO] metadata: {'render.modes': ['human', 'ansi']}\n",
      "00:00:00 [INFO] _max_episode_steps: 200\n",
      "00:00:00 [INFO] _elapsed_steps: None\n",
      "00:00:00 [INFO] id: Taxi-v3\n",
      "00:00:00 [INFO] entry_point: gym.envs.toy_text:TaxiEnv\n",
      "00:00:00 [INFO] reward_threshold: 8\n",
      "00:00:00 [INFO] nondeterministic: False\n",
      "00:00:00 [INFO] max_episode_steps: 200\n",
      "00:00:00 [INFO] _kwargs: {}\n",
      "00:00:00 [INFO] _env_name: Taxi\n"
     ]
    }
   ],
   "source": [
    "# Create the seeded environment, then log every attribute of the\n",
    "# environment object and of its spec.\n",
    "env = gym.make('Taxi-v3')\n",
    "env.seed(0)\n",
    "for attributes in (vars(env), vars(env.spec)):\n",
    "    for key, value in attributes.items():\n",
    "        logging.info('%s: %s', key, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SARSALambdaAgent:\n",
    "    \"\"\"Tabular SARSA(lambda) agent with an epsilon-greedy behaviour policy.\n",
    "\n",
    "    Action values are stored in ``self.q`` (one row per observation, one\n",
    "    column per action) and updated online through the eligibility trace\n",
    "    ``self.e``, which has the same shape.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.99, learning_rate=0.1, epsilon=0.01,\n",
    "            lambd=0.6, beta=1.):\n",
    "        \"\"\"Create the agent for an environment with discrete spaces.\n",
    "\n",
    "        Args:\n",
    "            env: object exposing ``observation_space.n`` and\n",
    "                ``action_space.n``; they fix the shape of the Q-table.\n",
    "            gamma: discount factor.\n",
    "            learning_rate: step size of the Q-table update.\n",
    "            epsilon: probability of a uniformly random action in\n",
    "                'train' mode.\n",
    "            lambd: eligibility-trace decay factor.\n",
    "            beta: weight of the decayed trace of the visited pair\n",
    "                (1. accumulates the trace, 0. replaces it with 1).\n",
    "        \"\"\"\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.lambd = lambd\n",
    "        self.beta = beta\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        # Default to greedy acting so that step() also works before the\n",
    "        # first reset() instead of raising AttributeError.\n",
    "        self.mode = None\n",
    "\n",
    "    def reset(self, mode=None):\n",
    "        \"\"\"Start a new episode.\n",
    "\n",
    "        Args:\n",
    "            mode: 'train' enables exploration and learning and clears the\n",
    "                trajectory and the eligibility trace; any other value\n",
    "                makes the agent act greedily without updates.\n",
    "        \"\"\"\n",
    "        self.mode = mode\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory = []\n",
    "            self.e = np.zeros(self.q.shape)\n",
    "\n",
    "    def step(self, observation, reward, done):\n",
    "        \"\"\"Choose an action and, in 'train' mode, learn from the last transition.\n",
    "\n",
    "        Args:\n",
    "            observation: index of the current state (row of ``self.q``).\n",
    "            reward: reward received on entering ``observation``.\n",
    "            done: whether ``observation`` is terminal.\n",
    "\n",
    "        Returns:\n",
    "            Index of the chosen action.\n",
    "        \"\"\"\n",
    "        # The random draw only happens in 'train' mode (short-circuit).\n",
    "        if self.mode == 'train' and np.random.uniform() < self.epsilon:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        else:\n",
    "            action = self.q[observation].argmax()\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory += [observation, reward, done, action]\n",
    "            # Two consecutive (observation, reward, done, action) records\n",
    "            # form one complete SARSA transition.\n",
    "            if len(self.trajectory) >= 8:\n",
    "                self.learn()\n",
    "        return action\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"End-of-episode hook; this agent has nothing to release.\"\"\"\n",
    "        pass\n",
    "\n",
    "    def learn(self):\n",
    "        \"\"\"Apply one SARSA(lambda) update from the two latest records.\"\"\"\n",
    "        (state, _, _, action,\n",
    "                next_state, reward, done, next_action) = self.trajectory[-8:]\n",
    "\n",
    "        # update eligibility trace: decay every entry, then bump the\n",
    "        # state-action pair that was just left\n",
    "        self.e *= (self.lambd * self.gamma)\n",
    "        self.e[state, action] = 1. + self.beta * self.e[state, action]\n",
    "\n",
    "        # update value; (1. - done) removes the bootstrap term when the\n",
    "        # next state is terminal\n",
    "        target = (reward + self.gamma\n",
    "                * self.q[next_state, next_action] * (1. - done))\n",
    "        td_error = target - self.q[state, action]\n",
    "        self.q += self.learning_rate * self.e * td_error\n",
    "\n",
    "\n",
    "# Build the agent; its Q-table is sized from env's observation and action spaces.\n",
    "agent = SARSALambdaAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] ==== train ====\n",
      "00:01:30 [INFO] ==== test ====\n",
      "00:01:30 [INFO] average episode reward = 7.44 ± 2.18\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZEUlEQVR4nO3deXxV1b338c8vcwhJCCQhIQkkQJjLGCZFEQgSHIr12iu2Vltui7W2Sm3rI9LawfJoffrcW31a9VKnelsvtdpbvSra4vRoy2BwZFSQKYISJpnndf84m+QknATCSXIC6/t+vc4r+/zW3mevs5J8s8/eK+eYcw4REfFLXKw7ICIirU/hLyLiIYW/iIiHFP4iIh5S+IuIeCgh1h04VdnZ2a64uDjW3RAROaMsWbJkq3Mup379jAn/4uJiKisrY90NEZEzipmtj1TXaR8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ8p/EVEPKTwFxHx0Bkzzz8ar31QTbeO7Xjk72t5/+PP6J2Xzobt+8hNT2HD9n1069gOgBWf7GbF5l0APPnN0SxYs40LeufSPiWBxxetZ0Lfznzjd5UM7ZZFYrxxfq8c/mPBesb3yaVrp3a8tX4nT71VxXfG9yQzNZEeue3567JPmTq8iCeXVPHqB1uo3n2Q0tx0Zk7uw95DR3lrww565LTntQ+qGVSYyfLNuzinRzavrNpC5/QUlm36jEVrt1PQIZW++el07ZjGoKJMbpr7DrdO7gPA0K5Z9MhJY9wvX+XigV3olJbEqx9sYdq5JQwszKRdUgJPLaliyYYdjOudS256Mtf/4S1mlJdy5KijR24a3/3ju5T37cy6bXtZvWUPeRkp/PjSfnRol8SPn1lKxYB8lm/axZFjx2r2e+SoIz8zhUm/ep1HvzacRWu3kxRvDO2WxVVzFjKhb2dWbN7Fo18bwQtLN/PJroOU5rbn16+sJqd9MovXbad7ThoT+uTy29fXcumgLpT3zWXNlj289/FnTB6QR25GCj/6y1I6tEvk2+NK+Wz/Ifp3yeTZ9zbTKS2Jh/++loyURG4Y35Mb//Ntnr7hXF7/sJpJ/fNCY5iRwh/f3EheRgrb9h5iUGEmkwbk8crKLXzw6R4yUxN5Y/VWOqYlcWVZEYO7duCRv69jcFEmo7tn8+AbH/HYgvWUdcvikoH5/OS/l3Nuz06M7ZXDi8s+ZWyvHLbsPsDRY9CvSwZvb9hBn7x0Vn2yh2HdsmiXFM/G7ft4fPEGbqnozfa9h7nj2eX8+NJ+XD6kkOn/Ucklg7qQEGfMfm4FT14/mh457Xl3406y0pL44gML2L73EJMH5PH9Sb35eMd+8jJTeHvDDtJTEokzY/bzy2mXmECvvHQyUxN47r3NVAzIJzHemDwgnz9VbuTPb38MwBXDCjmvNJuUxHh27D3Ek0uquGxIAQcOH+XTXQcYXJRFQVYqO/Ye4q/LP2FkSSdSEuO4c95K1m/bV/M71b9LBqO6d2JMaTbb9xwiMSGO+15ZzdjeObyw9BOGdcuifXICT1Ru5MDhY5zbsxOXDOzC1t0H2bn/MKmJ8azdtpe0pHji44zc9BTueelDbp7Yi/2Hj9I9O42d+w7zjzVb+Xjnfs4rzeGhN9Yyorgji9dt5/IhBSTGx/H8+5vZffAIyQlxfHlkN3buOwTA1r2H+Oo53Thy1PHoP9bxzsadzL95LHfOW0l+Zgr/WLOVL4/sxu4Dh+nSIZXR3Tvx9zXb2LbnIM7Bz55dzo3je1LerzNXP7iIf7tyMJs+O8DHO/bz4rJPmDwgj217DrF2616GFWexa/9h+uRncNfzK7hyeFe6dEjhk88O0DO3PR9t3cuz727i0sFd6NguiTvnraR353SmjSnmo+q9jOrRiV/N/5B3N+5kTM9sqncfJCkhjpsn9uLA4aOUFXcku30SZtasuWixej9/M6sA7gHigQedc3c1tn5ZWZk7nX/yemHp
Zr75+7dOr5MiIm3AyjsqSEmMP61tzWyJc66sfj0mp33MLB74DTAZ6AdcZWb9WmJfCz/a3hIPKyLSauKa+agfYnfOfwSw2jn3kXPuEDAXmBKjvoiItGkJcWdP+BcAG8PuVwW1Zle1Y39LPKyISKuJO4vCP9IzOeHig5lNN7NKM6usrq4+rR3NX/HpaW0nInI2i1X4VwFFYfcLgU31V3LOzXHOlTnnynJyTnhH0mbzf784iN6d05l1UV9y05P56ef7M2Vwl1PatqBDKhP65HL3FQP5waTeddpKstO4Z+rgE7bplJZETnpynVpGSkLNLBqAwqxUfvOloYwo7shfbjiXN2eV17QNKurAxH6d62z/0vfGnvB443rn8Oascv7/D8Yxunsnvn9hL7pnpwHwi3/6HEt+WM4dU/qf0L/fXlPGmJ7ZAAzt2oGsdok8+c3R/PKLg8hNT6Z35/Sada8b252RJR0jjk1JdhpPXDcagPNKsyOuczL/XFZYs9yhXSL9u2TQq3N7Hrh6KDPKS7lubHf++9tjuLDeeNT3lVHdapZLc9tHXOfCfp3p1TnU9ti0Eay98yLe+8mFXD2qK8OLs05YPyMlge45aXVqz994HrMu6sv//sLnACJud9xj00bwwNVD69Suv6AHlw8p4Lrzu1OYlUpF/zzuuGwA7ZLi+cPXR7J69mS+MKSAoo6pDT7++b1ymDt9FA9dW8bsLwyoqX/1nOKa5a+dW1xnmx9M6s1bP5rI/V+u7U9qYjxLfhj6ubtubPc66xd0SK1Z7pefwXfLezGypCN98zP40siuTB1exJVlRSz/2STuvPxzdbYtzW3PI18bXnM/JTGOdXddzHfLe/GVUd147ycX8q0LetS0f+O8Epb+dBKv3zKO74zvCcCPL629RNi1Y7ua3717pg5mcFGHE8YE4NGwfQJcPaorALdfcuLlxhnlpXV+Zo67fEgBP7m07vojSjqy6LYJjOudw1dGdeOl743luRvHsGDmeG4M+vvCjPPq/CyHu2lCKXEWek63Tu7D8YP8hDijU1pSxG2aQ0xm+5hZAvABMAH4GHgT+JJzbllD25zubJ/iW59rsO32S/rx+OINzL95bMT28+9+hQ3b9/Hq9y+gODut5rHW3XVxk/tx9wsrue/VNQA8/o2RnNOj6WFYf//H7+emJ7N4VjlVO/aRFB9HYnwcWU38oan/2L9fuJ4f/mUpM8pLmVHe64T1jxw9xtPvbOILQwqIizMm/utrfLhlD2//aGKT933cg69/xM+fW1GnHwCls57n8FHHqp9XkJzQ+IyHA4ePsm3vIc696+U69XV3Xcxd81bywGtrWP6zSbRLSqh5zo9/fSRFHdtRFEz5bcjRY46qHfvo1imt0fWa6tCRY3y668BJ93+6jh1zrN++j5LsE/sd6We6sedZfOtzXDGskF9+cVCT+uCcY+3WvXTPifyHN5KPqvc0uv6a6j08taSKH0zqfcI0yAVrtnHVbxfW3D+V39kxv3iZqh37WXzbBHIzUmrqmz/bz+g7X6ZzRjKLbgv9MTx2zPHz51bw9fNK6BL2h/BUvbzyUwZ0yayzn5bS0GyfmMzzd84dMbNvAy8Smur5cGPB31KmjSlh2piSBtv/31VD+PUrqynMavo3t74rhhVy36treOSrw08r+AEe+epwMlITa+7/y5gSHnpjLTdOKAWgMKv5wuPK4UXsPnDkhCPE4xLi4/inYbVHMr+bNoJXV1WfdvBD6Mh0z8EjXHd+jzr1300bwdzFG0mKP/kL1ZTEeAo6pNIlM4X+BZnMuqgvyYmh7W6Z1JsZ5aU1U+ZumlBKdnoy5/Q8te9HfJw1e/ADJCXEtVjwQ+h8caTgB3j5e2Npn1w3Bhp7nivvqDil70N9Ztak4AdOun6PnPbcUtEnYtvxV3DfuqAHAwszT2l/r98yjgOHj5GaVPcAIzc9hckD8vj6ebWvfuLijNsvPf0JiuP7NP5KtTXEbJ5/UzX3kf/087tz20V9
T+uxTufIv617r2on7ZLi6ZmbfvKVReSM0aaO/GPt+GkcqTWwsEOsuyAircjL9/ZR8IuI77wM/2i05NV3EZHW4t1pnz55p39O+8/fOqdZLv6KiMSad+GfeBozFY4b2rXhOdsiImcS7077tMD7I4mInHH8C/9Yd0BEpA3wLvx16C8i4mH4K/pFRDwM/4berkBExCfehf+UwS3ysQEiImcU78JfREQU/iIiXlL4i4h4SOEvIuIhhb+IiIcU/iIiHlL4i4h4SOEvIuIhhb+IiIcU/iIiHlL4i4h4SOEvIuIhhb+IiIcU/iIiHlL4i4h4SOEvIuIhhb+IiIcU/iIiHlL4i4h4SOEvIuIhhb+IiIcU/iIiHlL4i4h4SOEvIuKhqMLfzP6Pma00s/fM7L/MrENY20wzW21mq8xsUlh9mJm9H7Tda2YWTR9ERKTpoj3y/xswwDk3EPgAmAlgZv2AqUB/oAK4z8zig23uB6YDpcGtIso+iIhIE0UV/s65vzrnjgR3FwKFwfIUYK5z7qBzbi2wGhhhZvlAhnNugXPOAY8Bl0XTh6bITE1srV2JiLRpzXnOfxowL1guADaGtVUFtYJguX69VYT+3oiISMLJVjCz+UBehKZZzrmng3VmAUeAPxzfLML6rpF6Q/ueTugUEV27dj1ZV0VE5BSdNPydc+WNtZvZtcAlwARXe2hdBRSFrVYIbArqhRHqDe17DjAHoKysTIftIiLNJNrZPhXA/wI+75zbF9b0DDDVzJLNrITQhd3FzrnNwG4zGxXM8rkGeDqaPjSF/nqIiISc9Mj/JH4NJAN/C2ZsLnTOfdM5t8zMngCWEzoddINz7miwzfXAo0AqoWsE80541Jai9BcRAaIMf+dcz0baZgOzI9QrgQHR7FdERKLj1X/46sBfRCTEr/DXVE8REcCz8BcRkRCFv4iIh7wKf530EREJ8Sr8RUQkxKvw1/VeEZEQr8JfRERCvAp/p7P+IiKAZ+F/TNkvIgJ4Fv4iIhKi8BcR8ZDCX0TEQ36Fv875i4gAvoW/iIgAnoW/pnqKiIT4Ff7KfhERwLfwj3UHRETaCK/CX0REQrwKf32Sl4hIiFfhLyIiIV6F/7XnFMe6CyIibYJX4X/zxF6x7oKISJvgVfinJSXEugsiIm2CV+EfF2ex7oKISJvgVfiLiEiIwl9ExEMKfxERDyn8RUQ8pPAXEfGQwl9ExEMKfxERDyn8RUQ8pPAXEfGQwl9ExEPNEv5m9n0zc2aWHVabaWarzWyVmU0Kqw8zs/eDtnvNTO+5ICLSyqIOfzMrAiYCG8Jq/YCpQH+gArjPzOKD5vuB6UBpcKuItg8iItI0zXHk/2/ALdT9iNwpwFzn3EHn3FpgNTDCzPKBDOfcAhf6WK3HgMuaoQ8iItIEUYW/mX0e+Ng59269pgJgY9j9qqBWECzXrzf0+NPNrNLMKqurq6PpqoiIhDnpG9yb2XwgL0LTLOA24MJIm0WouUbqETnn5gBzAMrKyvQBvCIizeSk4e+cK49UN7PPASXAu8E120LgLTMbQeiIvihs9UJgU1AvjFAXEZFWdNqnfZxz7zvncp1zxc65YkLBPtQ59wnwDDDVzJLNrITQhd3FzrnNwG4zGxXM8rkGeDr6pyEiIk3RIp9r6JxbZmZPAMuBI8ANzrmjQfP1wKNAKjAvuImISCtqtvAPjv7D788GZkdYrxIY0Fz7FRGRptN/+IqIeEjhLyLiIYW/iIiHFP4iIh5S+IuIeEjhLyLiIYW/iIiHFP4iIh5S+IuIeEjhLyLiIYW/iIiHFP4iIh5S+IuIeEjhLyLiIW/CPy0pPtZdEBFpM7wJ/3F9cmPdBRGRNsOb8P/ZFH1+jIjIcd6Ef/vkFvnEShGRM5I34S8iIrUU/iIiHlL4i4h4yJvwN4t1D0RE2g5vwl9ERGop/EVEPKTwFxHxkMJfRMRD3oS/rveKiNTyJvxFRKSWwl9ExEMKfxERDyn8RUQ8pPAXEfGQwl9ExEMK
fxERDyn8RUQ8FHX4m9l3zGyVmS0zs7vD6jPNbHXQNimsPszM3g/a7jVrnffbbKXdiIicEaL6bEMzGwdMAQY65w6aWW5Q7wdMBfoDXYD5ZtbLOXcUuB+YDiwEngcqgHnR9ENERJom2iP/64G7nHMHAZxzW4L6FGCuc+6gc24tsBoYYWb5QIZzboFzzgGPAZdF2QcREWmiaMO/F3CemS0ys9fMbHhQLwA2hq1XFdQKguX69YjMbLqZVZpZZXV1dZRdFRGR40562sfM5gN5EZpmBdtnAaOA4cATZtadyO+j5hqpR+ScmwPMASgrK2twPRERaZqThr9zrryhNjO7HvhzcApnsZkdA7IJHdEXha1aCGwK6oUR6i0uPk4XfEVEjov2tM9fgPEAZtYLSAK2As8AU80s2cxKgFJgsXNuM7DbzEYFs3yuAZ6Osg8iItJEUc32AR4GHjazpcAh4NrgVcAyM3sCWA4cAW4IZvpA6CLxo0AqoVk+mukjItLKogp/59wh4OoG2mYDsyPUK4EB0exXRESio//wFRHxkMJfRMRDCn8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ8p/EVEPKTwFxHxkMJfRMRDCn8REQ9FFf5mNtjMFprZO2ZWaWYjwtpmmtlqM1tlZpPC6sPM7P2g7V4zs2j6ICIiTRftkf/dwE+dc4OB24P7mFk/YCrQH6gA7jOz+GCb+4HpQGlwq4iyDyIi0kTRhr8DMoLlTGBTsDwFmOucO+icWwusBkaYWT6Q4Zxb4JxzwGPAZVH2QUREmighyu1nAC+a2S8J/SE5J6gXAAvD1qsKaoeD5fr1iMxsOqFXCXTt2jXKroqIyHEnDX8zmw/kRWiaBUwAvuuce8rM/hl4CCgHIp3Hd43UI3LOzQHmAJSVlTW4noiINM1Jw985V95Qm5k9BtwU3P0T8GCwXAUUha1aSOiUUFWwXL8uIiKtKNpz/puAscHyeODDYPkZYKqZJZtZCaELu4udc5uB3WY2Kpjlcw3wdJR9EBGRJor2nP83gHvMLAE4QHB+3jm3zMyeAJYDR4AbnHNHg22uBx4FUoF5wU1ERFpRVOHvnHsDGNZA22xgdoR6JTAgmv2KiEh09B++IiIeUviLiHhI4S8i4iGFv4iIhxT+IiIeUviLiHhI4S8i4iGFv4iIhxT+IiIeUviLiHhI4S8i4iGFv4iIhxT+IiIeivYtndu833xpKGnJ8SdfUUTEI2d9+F88MD/WXRARaXN02kdExEMKfxERDyn8RUQ8pPAXEfGQwl9ExEMKfxERDyn8RUQ8pPAXEfGQOedi3YdTYmbVwPrT3Dwb2NqM3TmTaSxqaSxqaSxCzsZx6Oacy6lfPGPCPxpmVumcK4t1P9oCjUUtjUUtjUWIT+Og0z4iIh5S+IuIeMiX8J8T6w60IRqLWhqLWhqLEG/GwYtz/iIiUpcvR/4iIhJG4S8i4qGzOvzNrMLMVpnZajO7Ndb9aS5m9rCZbTGzpWG1jmb2NzP7MPiaFdY2MxiDVWY2Kaw+zMzeD9ruNTML6slm9segvsjMilv1CTaBmRWZ2StmtsLMlpnZTUHdq/EwsxQzW2xm7wbj8NOg7tU4hDOzeDN728yeDe57OxYROefOyhsQD6wBugNJwLtAv1j3q5me2/nAUGBpWO1u4NZg+VbgF8Fyv+C5JwMlwZjEB22LgdGAAfOAyUH9W8ADwfJU4I+xfs6NjEU+MDRYTgc+CJ6zV+MR9Ll9sJwILAJG+TYO9cbkZuBx4NngvrdjEXF8Yt2BFvzGjwZeDLs/E5gZ63414/Mrrhf+q4D8YDkfWBXpeQMvBmOTD6wM
q18F/Hv4OsFyAqH/eLRYP+dTHJengYk+jwfQDngLGOnrOACFwEvA+LDw93IsGrqdzad9CoCNYfergtrZqrNzbjNA8DU3qDc0DgXBcv16nW2cc0eAz4BOLdbzZhK89B5C6KjXu/EITnO8A2wB/uac83IcAr8CbgGOhdV8HYuIzubwtwg1H+e1NjQOjY3PGTd2ZtYeeAqY4Zzb1diqEWpnxXg454465wYTOuodYWYDGln9rB0HM7sE2OKcW3Kqm0SonRVj0ZizOfyrgKKw+4XAphj1pTV8amb5AMHXLUG9oXGoCpbr1+tsY2YJQCawvcV6HiUzSyQU/H9wzv05KHs7Hs65ncCrQAV+jsO5wOfNbB0wFxhvZr/Hz7Fo0Nkc/m8CpWZWYmZJhC7KPBPjPrWkZ4Brg+VrCZ37Pl6fGsxOKAFKgcXBy97dZjYqmMFwTb1tjj/WFcDLLji52dYEfX8IWOGc+9ewJq/Gw8xyzKxDsJwKlAMr8WwcAJxzM51zhc65YkK/9y87567Gw7FoVKwvOrTkDbiI0OyPNcCsWPenGZ/XfwKbgcOEjkD+hdD5xpeAD4OvHcPWnxWMwSqC2QpBvQxYGrT9mtr/+E4B/gSsJjTboXusn3MjYzGG0Mvt94B3gttFvo0HMBB4OxiHpcDtQd2rcYgwLhdQe8HX67Gof9PbO4iIeOhsPu0jIiINUPiLiHhI4S8i4iGFv4iIhxT+IiIeUviLiHhI4S8i4qH/AU3RWpzYP7jiAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):\n",
    "    \"\"\"Run one episode of `agent` in `env`.\n",
    "\n",
    "    Every observation, including a terminal one, is passed to\n",
    "    `agent.step`; the action returned for a terminal observation is not\n",
    "    sent to the environment.\n",
    "\n",
    "    Args:\n",
    "        env: environment with `reset()`, `step(action)` returning a\n",
    "            4-tuple, and `render()`.\n",
    "        agent: object with `reset(mode=...)`, `step(observation, reward,\n",
    "            done)` and `close()`.\n",
    "        max_episode_steps: stop after this many environment steps; a\n",
    "            falsy value means no limit.\n",
    "        mode: forwarded to `agent.reset`.\n",
    "        render: call `env.render()` after every agent step.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of (sum of rewards, number of environment steps taken).\n",
    "    \"\"\"\n",
    "    observation = env.reset()\n",
    "    reward, done = 0., False\n",
    "    agent.reset(mode=mode)\n",
    "    total_reward, step_count = 0., 0\n",
    "    out_of_steps = False\n",
    "    while not out_of_steps:\n",
    "        action = agent.step(observation, reward, done)\n",
    "        if render:\n",
    "            env.render()\n",
    "        if done:\n",
    "            break\n",
    "        observation, reward, done, _ = env.step(action)\n",
    "        total_reward += reward\n",
    "        step_count += 1\n",
    "        out_of_steps = bool(max_episode_steps) and (\n",
    "                step_count >= max_episode_steps)\n",
    "    agent.close()\n",
    "    return total_reward, step_count\n",
    "\n",
    "\n",
    "# Train until the mean reward of the most recent (up to 200) episodes\n",
    "# exceeds env.spec.reward_threshold; there is no cap on the episode count.\n",
    "logging.info('==== train ====')\n",
    "episode_rewards = []\n",
    "for episode in itertools.count():\n",
    "    # NOTE(review): presumably env.unwrapped bypasses gym's own step-limit\n",
    "    # wrapper so that play_episode enforces max_episode_steps - confirm.\n",
    "    episode_reward, elapsed_steps = play_episode(env.unwrapped, agent,\n",
    "            max_episode_steps=env.spec.max_episode_steps, mode='train')\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('train episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "    if np.mean(episode_rewards[-200:]) > env.spec.reward_threshold:\n",
    "        break\n",
    "# Learning curve: total reward of each training episode.\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "\n",
    "# Evaluate over 100 episodes with play_episode's default mode=None, in which\n",
    "# the agent acts greedily and does not update its Q-table.\n",
    "logging.info('==== test ====')\n",
    "# Rebinding episode_rewards discards the training history collected above.\n",
    "episode_rewards = []\n",
    "for episode in range(100):\n",
    "    episode_reward, elapsed_steps = play_episode(env, agent)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('test episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        np.mean(episode_rewards), np.std(episode_rewards))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the environment now that training and evaluation have finished.\n",
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
